Sparse tensors #1998
Conversation
type SparseTensorPrimitive<const D: usize> = SparseCOOTensor<B, D>;

fn sparse_to_sparse<const D: usize>(
should this be dense_to_sparse?
I agree with you. Really it should be sparse_dense_to_sparse based on the naming convention (right now it's sparse_to_dense to represent the to_dense function on a sparse tensor)... but that's very verbose and/or confusing, and I like dense_to_sparse better.

Given everything is broken out into its own crate/extension, I could probably just drop the sparse_ at the start of some of these though? It originally came from the other tensor kinds using this naming scheme before I broke everything out into its own crate.
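For illustration only, here is a minimal sketch of the two naming options being discussed. The trait and types (SparseOps, DensePrimitive, SparsePrimitive) are placeholders invented for this example, not Burn's real primitives:

```rust
// Hypothetical placeholder types; not Burn's actual primitives.
struct DensePrimitive;
struct SparsePrimitive;

// Hypothetical op trait, just to show the two naming options side by side.
trait SparseOps {
    // The reviewer's suggestion: name the conversion after its input (dense -> sparse).
    fn dense_to_sparse(dense: DensePrimitive) -> SparsePrimitive;
    // The existing convention: `sparse_to_dense` backs `to_dense` on a sparse tensor.
    fn sparse_to_dense(sparse: SparsePrimitive) -> DensePrimitive;
}
```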
I've finally done it... I think I landed on a good sparse API. 😪 I have spent a lot of time agonizing over how to make sparse tensors fit into the API nicely. Every solution I found felt like it was compromising on something, so I eventually settled on something that did the job but had some inconveniences in the user-facing API, and really annoying delegation issues... but no more! Now I have a solution that leaves me satisfied (though it did push my knowledge of Rust's type system to its limits, which might not be the best sign 👀).

// Tensor representations:
let sparse_tensor = Tensor::<B, 1, Float, Sparse<COO, B>>::empty(...);
let dense_tensor = Tensor::<B, 1, Float, Dense>::empty(...);
let dense_tensor = Tensor::<B, 1, Float>::empty(...);
// And other types work fine too:
let sparse_tensor = Tensor::<B, 1, Bool, Sparse<COO, B>>::empty(...);
let sparse_tensor = Tensor::<B, 1, Int, Sparse<COO, B>>::empty(...);

I'd appreciate some feedback on the implementation as it is now (specifically, the changes in burn-tensor; burn-sparse is old and still to be renovated). I think this is a solution that doesn't compromise on anything and requires a fairly minimal amount of rewriting (it compiles so far). The general idea is that it adds a representation generic which defaults to Dense.
Sorry I haven't had the time to take a good look at this, I'll make time to give you some feedback at the beginning of next week 🙂
Very appreciated. I've spent some time cleaning things up a lot, so honestly it's probably for the best you didn't have time until now 😄. I threw together a quick mock version of what I'm doing. Given it helped me figure everything out, it will surely also help you guys navigate all the changed files more easily. I'm not yet set on the names (they don't feel right to me), but the final representation I settled on has the following:

trait TensorKind<Backend> {
type Primitive;
}
trait TensorStorage<Backend> {}
struct Float;
struct Bool;
struct Int;
struct Dense;
struct Sparse<B: Backend, SR: SparseStorage<B>> {
p: PhantomData<(B, SR)>,
}

An important part of the above is that we have the SparseStorage trait:

/// This lets us offload the representation details to the backend
/// e.g. some might support COO, CSC, etc
trait SparseStorage<B> {
type Primitive<K: TensorKind<B>>;
}

We can then implement all of these things as usual. Folded because it's not really interesting.

// TensorKind implementations as usual
impl<B: Backend> TensorKind<B> for Float {
type Primitive = B::FloatPrimitive;
}
impl<B: Backend> TensorKind<B> for Bool {
type Primitive = B::BoolPrimitive;
}
impl<B: Backend> TensorKind<B> for Int {
type Primitive = B::IntPrimitive;
}
// Storage implementations
impl<B: Backend> TensorStorage<B> for Dense {}
impl<B: Backend, SR: SparseStorage<B>> TensorStorage<B> for Sparse<B, SR> {}

Now the part that gives me the most hesitancy as to whether this is a good approach is where we combine the two to get the correct primitive. That is done using the TensorRepr trait:

trait TensorRepr {
type Primitive;
}
impl<B: Backend, K: TensorKind<B>> TensorRepr for (B, K, Dense) {
type Primitive = K::Primitive;
}
impl<B: Backend, K: TensorKind<B>, SR: SparseStorage<B>> TensorRepr for (B, K, Sparse<B, SR>) {
type Primitive = SR::Primitive<K>;
}
// Then when we define tensors:
struct Tensor<B: Backend, K: TensorKind<B> = Float, S: TensorStorage<B> = Dense>
where
(B, K, S): TensorRepr,
{
primitive: <(B, K, S) as TensorRepr>::Primitive,
}

Because Rust has no specialisation (or meaningful negative trait bounds), and because of how I've broken up the kinds and representations, we get somewhat infectious where bounds. The mockup above is describing my changes to burn-tensor. Something like below will compile, if you want to test things:

use burn::backend::ndarray::NdArrayDevice;
use burn::backend::sparse::COO;
use burn::backend::NdArray;
use burn::prelude::*;
fn main() {
type B = NdArray;
let device = NdArrayDevice::Cpu;
let dense_tensor = Tensor::<B, 2>::from_floats(
[
[1.0, 0.0, 0.0, 0.0],
[0.0, 1.0, 0.0, 0.0],
[0.0, 0.0, 1.0, 0.0],
[0.0, 0.0, 0.0, 1.0],
],
&device,
);
let sddmm_l = Tensor::<B, 2>::from_floats([[1, 1, 1, 1]], &device).transpose();
let sddmm_r = Tensor::<B, 2>::from_floats([[1, 2, 3, 4]], &device);
let sparse_tensor = Tensor::<B, 2>::from_floats(
[
[0.0, 0.0, 0.1, 1.0],
[4.0, 0.0, 2.0, 0.0],
[2.5, 0.0, 0.0, 0.0],
[0.0, 8.0, 0.0, 1.0],
],
&device,
)
.into_sparse::<COO>();
// let multiplied = sparse_tensor.clone().spmm(dense_tensor.clone());
let spmm = sparse_tensor.clone().spmm(dense_tensor.clone());
let sddmm = sparse_tensor.clone().sddmm(sddmm_l, sddmm_r);
println!("{}", spmm);
println!("{}", sddmm.into_dense());
let sparse_tensor = sparse_tensor.unsqueeze::<3>().repeat_dim(0, 2);
println!("{}", sparse_tensor.into_dense());
}
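To make the "infectious" bound mentioned above concrete, here is a minimal, self-contained sketch using the mockup's names (simplified stand-ins, not the actual crates):

```rust
// Simplified stand-ins for the mockup traits above.
trait Backend {}
trait TensorKind<B: Backend> {
    type Primitive;
}
trait TensorStorage<B: Backend> {}
trait TensorRepr {
    type Primitive;
}

struct Tensor<B: Backend, K: TensorKind<B>, S: TensorStorage<B>>
where
    (B, K, S): TensorRepr,
{
    primitive: <(B, K, S) as TensorRepr>::Primitive,
}

// Without specialisation, every generic helper that touches the tensor has to
// repeat the same bound, which is what makes it "infectious".
fn describe<B: Backend, K: TensorKind<B>, S: TensorStorage<B>>(_t: &Tensor<B, K, S>) -> &'static str
where
    (B, K, S): TensorRepr,
{
    "a tensor"
}
```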
First of all, great work so far! It seems like you've come a long way 👏

Regarding the implementation, I went through the suggested changes in the PR a couple of times and I'm not entirely convinced yet by the added traits required just to add a sparse tensor primitive. While the user-facing API remains essentially unchanged (very nice!), at this stage it feels like the implementation is only halfway to a backend decorator. I do not necessarily mind the additional...

I'll come back to this draft later this week to see if my mind has changed 😅 and otherwise try to come up with alternatives to help out.
This is pretty much how I feel about it. It is the "best" solution I've found (no breaking changes, quite extensible, lets backends choose how to implement), I just wish the implementation could be simpler... It feels convoluted. Not convoluted enough that I couldn't live with it, but enough that I hope there is a better way. And yes, it feels like a very round-about way of doing what a decorator would. There were some issues I ran into with the decorator + backend extension combo, unfortunately.
Since I made my last comment I have begun to suspect it might be possible to combine the backend extension + representation approaches together into a happy middle ground? Rip all the traits/functionality out of the tensor storage and put them in a backend extension which takes the TensorStorage as a generic. Then implement the sparse API on top of that extension.

Instead of a decorator, I could just do a blanket implementation of this extension (for the decorator COO SparseStorage only at the moment) on all backends (not sure if this will work). That would eliminate the delegation issues. Individual backends can implement the extension with a different TensorStorage type later, if they want to support their own/more kinds.

Sorry if that was a bit of a mess of words, it's 3am as I write this 😅. I definitely intend to explore that approach soonish. For now though I want to get all of the operations implemented for the decorator so that it is complete (because changing the API doesn't invalidate that code, luckily). Gives you time to think on it too, while I do that.
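A rough, hypothetical sketch of that "extension + blanket impl" idea, purely for illustration — the trait and type names (SparseExt, the Vec-of-triplets primitive) are invented here and are not part of the PR:

```rust
trait Backend {}

// Marker type standing in for one sparse storage scheme.
struct COO;

// A backend extension generic over the storage scheme; ops reduced to one for brevity.
trait SparseExt<SR>: Backend {
    type SparsePrimitive;
    fn sparse_nonzeros(tensor: &Self::SparsePrimitive) -> usize;
}

// Blanket implementation: every backend gets COO support "for free", which is the
// part that would replace the decorator and its delegation boilerplate.
impl<B: Backend> SparseExt<COO> for B {
    // Placeholder COO representation: (row, col, value) triplets.
    type SparsePrimitive = Vec<(usize, usize, f32)>;

    fn sparse_nonzeros(tensor: &Self::SparsePrimitive) -> usize {
        tensor.len()
    }
}
```

Individual backends could still opt into a different storage type by implementing the extension for their own marker, which mirrors the "their own/more kinds" point above.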
@McArthur-Alford Thank you for sharing the issues you ran into and your thought process! It helps to know more specifically why some decisions were made. To give a more thorough review with appropriate suggestions, I feel like I would have to dive into it a bit more to iterate over some design choices (which I would like to do). We just released today, so I should be able to reserve some time for that in the following week 🤞

I think this would alleviate my biggest pain point with the current implementation. I feel like...
Yes please go ahead 🙏 don't want to slow you down hehe
This PR has been marked as stale because it has not been updated for over a month.
Apologies bot. I am in fact still alive, just got incredibly busy the last few weeks. Should be back to this soonish. |
No worries, I have been quite busy as well 😅 And don't worry about the conflicts for now; we just refactored the tensor ops to remove the const D generic for primitives.
This PR has been marked as stale because it has not been updated for over a month.
I'm still not dead, I promise. Probably about a week or two away from having the time to actually go through and get this thing done.
This PR has been marked as stale because it has not been updated for over a month |
Pull Request Template
This is a draft PR for sparse tensor support. It is probably still quite a ways away from being finished in my eyes; however, I wanted to make this early to get some eyes on the matter throughout the process. Because of the nature of this draft, I haven't really done much documenting/testing at all, and it will probably fail many tests. I'll get around to those problems once the code is stable.

Right now there isn't a lot of functionality (though spmm for the COO tensor is working nicely). The goal at the moment is to get feedback on the overall architectural choices. I'm pretty happy, but there are some things I suspect could be done better/differently, and I'd like to make those kinds of changes before bulk implementing functionality.
Checklist

- The run-checks all script has been executed.

Related Issues/PRs
Changes
My motivation/end goal is to do GNNs. As a result, some of the things I work on might be skewed towards that. If there are other sparse tensor things that you might want, tell me.
What are the main changes?
- SparseBackend backend extension for sparse functionality. It will contain most of the typical tensor ops you would find in Int/Bool/FloatTensorOps, plus some special sparse ones like spmm. Certain operations are not easily achievable with sparse tensors and require conversion back to dense. Due to the potential performance costs of this, I am intentionally not adding such ops, instead requiring the user to deliberately convert back to dense.
- Sparse tensor kind and the associated tensor API functionality (but only for tensors where B: SparseBackend). Right now, Sparse and SparseBackend are in the burn-tensor crate, though I am debating moving them into their own crate.
- burn-sparse crate. This provides a SparseDecorator struct that implements SparseBackend. It also provides support for multiple representations of sparse tensors. The point of this is to get all backends supporting sparse tensors quickly and easily, at the cost of not having custom kernels, etc. My goals are probably to just get COO, CSR and CSC working before I'm happy with it. (A rough sketch of what such a decorator could look like follows this list.)

Some things I haven't done but intend to address/am seeking ideas on:

- The TensorData struct. I haven't even looked at this yet, but it seems like it is built purely for dense tensors. Given sparse tensors can have quite different representations, I'm not sure wrapping it in an enum would make sense. Unfortunately, I have to do something with this, because it feels like a bad solution to make users pull the coordinates/values or other representation-specific details out of their tensor without an API.
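As referenced in the burn-sparse bullet above, here is a rough sketch of the shape such a decorator could take. The SparseRepresentation trait and the PhantomData layout are guesses based purely on the names in this list, not the actual burn-sparse code:

```rust
use std::marker::PhantomData;

// Simplified stand-ins; the real traits live in burn-tensor / burn-sparse.
trait Backend {}
trait SparseRepresentation {}

struct COO;
struct CSR;
impl SparseRepresentation for COO {}
impl SparseRepresentation for CSR {}

// The decorator wraps an inner backend and fixes a sparse representation,
// implementing the sparse ops on top of the inner backend's dense ops.
struct SparseDecorator<B: Backend, R: SparseRepresentation> {
    p: PhantomData<(B, R)>,
}

// It stays a full backend, so existing dense code keeps working unchanged.
impl<B: Backend, R: SparseRepresentation> Backend for SparseDecorator<B, R> {}
```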
Testing

I've done no testing so far, beyond a little demo. If you want to get a sense for the API, the below code should run:
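The demo itself does not appear here; presumably it is the snippet shared earlier in the conversation. A condensed version of that snippet, for reference, using the same imports and method names (COO storage on the NdArray backend):

```rust
use burn::backend::ndarray::NdArrayDevice;
use burn::backend::sparse::COO;
use burn::backend::NdArray;
use burn::prelude::*;

fn main() {
    type B = NdArray;
    let device = NdArrayDevice::Cpu;

    // A small dense matrix, and a sparse one built by converting a dense tensor.
    let dense = Tensor::<B, 2>::from_floats([[1.0, 0.0], [0.0, 1.0]], &device);
    let sparse = Tensor::<B, 2>::from_floats([[0.0, 2.0], [3.0, 0.0]], &device).into_sparse::<COO>();

    // Sparse-dense matrix multiplication, printed directly as in the original demo.
    let spmm = sparse.spmm(dense);
    println!("{}", spmm);
}
```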